Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose quantities related to generic frames #148

Merged
merged 30 commits into from
May 23, 2024

Conversation

xela-95
Copy link
Member

@xela-95 xela-95 commented May 7, 2024

Closes #147


📚 Documentation preview 📚: https://jaxsim--148.org.readthedocs.build//148/

@xela-95 xela-95 changed the title Feature/expose frame quantities Expose quantities related to generic frames May 7, 2024
@xela-95 xela-95 marked this pull request as ready for review May 8, 2024 15:15
@xela-95
Copy link
Member Author

xela-95 commented May 8, 2024

The implementation needed by issue #147 have been completed. Now the unit test for the jacobian function of frame module is taking a lot to complete, but unfortunately I was not able to speed up the computations from the test by using jax.vmap since that would raise Jax errors like TracerArrayConversionError or TracerIntegerConversionError.

@diegoferigo @flferretti could you please review the math in the frame functions and the unit tests implementation?

@flferretti
Copy link
Collaborator

Thanks a lot @xela-95 for the PR! Could you please explain how did you get the errors using the vmap? It would be nice to be able to reproduce and eventually find a solution for it

src/jaxsim/api/frame.py Outdated Show resolved Hide resolved
src/jaxsim/api/frame.py Outdated Show resolved Hide resolved
@xela-95
Copy link
Member Author

xela-95 commented May 8, 2024

Thanks a lot @xela-95 for the PR! Could you please explain how did you get the errors using the vmap? It would be nice to be able to reproduce and eventually find a solution for it

For example, if I try to mimick what is done to test the link jacobians in https://github.com/xela-95/jaxsim/blob/96c600dd2e7fc9a0136800924fde8bbe40d9f5cd/tests/test_api_link.py#L139-L141

but instead calling js.frame.jacobian like:

    J_WL_frames = jax.vmap(
        lambda idx: js.frame.jacobian(model=model, data=data, frame_index=idx)
    )(jnp.array(frame_indexes))

I get:

tests/test_api_frame.py:122: in <lambda>
    lambda idx: js.frame.jacobian(model=model, data=data, frame_index=idx)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

model = JaxSimModel(model_name='ergoCub')
data = JaxSimModelData(velocity_representation=<VelRepr.Inertial: 3>, state=ODEState(physics_model=PhysicsModelState(joint_po...loat64[3])>with<DynamicJaxprTrace(level=2/0)>, time_ns=Traced<ShapedArray(uint64[])>with<DynamicJaxprTrace(level=2/0)>)

    @functools.partial(jax.jit, static_argnames=["frame_index", "output_vel_repr"])
    def jacobian(
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        *,
        frame_index: jtp.IntLike,
        output_vel_repr: VelRepr | None = None,
    ) -> jtp.Matrix:
        """
        Compute the free-floating jacobian of the frame.
    
        Args:
            model: The model to consider.
            data: The data of the considered model.
            frame_index: The index of the frame.
            output_vel_repr:
                The output velocity representation of the free-floating jacobian.
    
        Returns:
            The 6×(6+n) free-floating jacobian of the frame.
    
        Note:
            The input representation of the free-floating jacobian is the active
            velocity representation.
        """
    
        output_vel_repr = (
            output_vel_repr if output_vel_repr is not None else data.velocity_representation
        )
    
        # Get the free-floating jacobian of the parent link in body-fixed output representation
        L = (
>           model.description.get()
            .frames[frame_index - model.number_of_links()]
            .parent.index
        )
E       jax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[].
E       The error occurred while tracing the function jacobian at /home/acroci/repos/jaxsim/src/jaxsim/api/frame.py:127 for jit. This value became a tracer due to JAX operations on these lines:
E       
E         operation a:i64[] = convert_element_type[new_dtype=int64 weak_type=False] b
E           from line /home/acroci/repos/jaxsim/src/jaxsim/api/frame.py:160:16 (jacobian)
E       
E         operation a:i64[] = sub b c
E           from line /home/acroci/repos/jaxsim/src/jaxsim/api/frame.py:160:16 (jacobian)
E       See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError
E       --------------------
E       For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

src/jaxsim/api/frame.py:159: TracerIntegerConversionError

Copy link
Member

@diegoferigo diegoferigo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, looks good!

For the records, JIT compilation is slow because it has to compile the frame.transform and frame.jacobian functions for each frame. After the first compilation, the compiled functions should work also on different instances of model and data.

Users, generally, don't need to calculate these quantities for all frames, contrarily to links. For the time being, JIT speed is not a major concern.

src/jaxsim/api/frame.py Outdated Show resolved Hide resolved
src/jaxsim/api/frame.py Outdated Show resolved Hide resolved
@flferretti
Copy link
Collaborator

Thanks for providing an example. I'll take a look at it, in the meanwhile for me it is good to go

@xela-95
Copy link
Member Author

xela-95 commented May 8, 2024

Perfect, thanks for the clear explanations!!

@flferretti
Copy link
Collaborator

I just want to point out that from E operation a:i64[] = sub b c I get that the error is raised from the subtraction operation frame_index - model.number_of_links(), therefore we must focus on the type of frame_index when it gets passed to the function.

@xela-95
Copy link
Member Author

xela-95 commented May 8, 2024

I just want to point out that from E operation a:i64[] = sub b c I get that the error is raised from the subtraction operation frame_index - model.number_of_links(), therefore we must focus on the type of frame_index when it gets passed to the function.

Nice catch! I'll try to iterate on that point to find a solution exploiting Jax capabilities

@xela-95
Copy link
Member Author

xela-95 commented May 8, 2024

The CI is still failing with error Error: The operation was canceled.. Do you think it's due to some maximum time allowed for action runner?

Could you re-run one of the failing actions, enabling the debug logging to have more details? (I do not have the permissions) https://docs.github.com/en/actions/managing-workflow-runs/re-running-workflows-and-jobs

@xela-95 xela-95 force-pushed the feature/expose-frame-quantities branch from fcaaa94 to 52331b0 Compare May 10, 2024 07:45
@xela-95
Copy link
Member Author

xela-95 commented May 10, 2024

Brief recap of yesterday

This PR is not yet merged since modifying the unit test checking the match between the jacobian of frames computed in Jaxsim and iDynTree the test started to fail. Since the only model used in tests with frames (that are not links) is ergoCub, I needed easier examples to work on to debug the function computing the jacobian.

I started updating the simple box model (no joints) by adding a frame attached to the only link of the model using rod APIs: https://github.com/xela-95/jaxsim/blob/3d4261586efe4b09b1e0b30ff668a0aa2f99c5e7/tests/conftest.py#L130-L136

In this case the unit test passed.

Then, I wanted to use frames for more complex models, like the UR10 manipulator. These models are loaded using the robot-descriptions package and they already have in their URDF descriptions dummy frames, but they were not parsed by rod.

@diegoferigo added support to frames super quickly in #150. Unfortunately, the issue is that these frames are not loaded in KynDynComputations so it's not possible to use them for the unit test.

@xela-95
Copy link
Member Author

xela-95 commented May 10, 2024

Right now on ErgoCubReduced the jacobians that are not matching are the ones for the frames:

  • l_foot_front
  • l_foot_rear
  • l_hip_3
  • l_shoulder_3
  • r_hip_3
  • r_shoulder3

These test cases are failing in all 3 velocity representations (inertial, body and fixed). All the other frames are passing the test.

@xela-95
Copy link
Member Author

xela-95 commented May 10, 2024

Right now on ErgoCubReduced the jacobians that are not matching are the ones for the frames:

  • l_foot_front
  • l_foot_rear
  • l_hip_3
  • l_shoulder_3
  • r_hip_3
  • r_shoulder3

These test cases are failing in all 3 velocity representations (inertial, body and fixed). All the other frames are passing the test.

Ok I found out the the culprit is actually not the jacobian computation but the transform function of the frame, since after refactoring the test I saw it fails on the same frames. 🕵🏻

@xela-95
Copy link
Member Author

xela-95 commented May 10, 2024

Here's the log containg the homogeneous transforms from world to frame that are failing:

Logs
jaxsim[100706] ERROR Assertion failed for frame: l_foot_front
jaxsim[100706] DEBUG W_H_F_js:
jaxsim[100706] DEBUG [[-0.35031 -0.67322 -0.6512   0.41084]
 [ 0.928   -0.34364 -0.14396  0.4807 ]
 [-0.12686 -0.65474  0.74513  0.39986]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] DEBUG W_H_F_iDynTree:
jaxsim[100706] DEBUG [[-0.54756 -0.36036 -0.75519 -0.36317]
 [-0.15713 -0.84218  0.51579 -0.03562]
 [-0.82188  0.40109  0.40453  0.33124]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] ERROR Assertion failed for frame: l_foot_rear
jaxsim[100706] DEBUG W_H_F_js:
jaxsim[100706] DEBUG [[-0.35031 -0.67322 -0.6512   0.45261]
 [ 0.928   -0.34364 -0.14396  0.37004]
 [-0.12686 -0.65474  0.74513  0.41499]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] DEBUG W_H_F_iDynTree:
jaxsim[100706] DEBUG [[-0.54756 -0.36036 -0.75519 -0.29787]
 [-0.15713 -0.84218  0.51579 -0.01688]
 [-0.82188  0.40109  0.40453  0.42925]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] ERROR Assertion failed for frame: l_hip_3
jaxsim[100706] DEBUG W_H_F_js:
jaxsim[100706] DEBUG [[-0.80645 -0.10969 -0.58104 -0.02238]
 [ 0.46848 -0.71808 -0.51467  0.27463]
 [-0.36078 -0.68726  0.63049  0.81591]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] DEBUG W_H_F_iDynTree:
jaxsim[100706] DEBUG [[-0.37434 -0.05964  0.92537 -0.24618]
 [ 0.00864 -0.99811 -0.06083  0.09552]
 [ 0.92725 -0.01478  0.37415  0.80239]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] ERROR Assertion failed for frame: l_shoulder_3
jaxsim[100706] DEBUG W_H_F_js:
jaxsim[100706] DEBUG [[-0.35031 -0.67322 -0.6512   0.46391]
 [ 0.928   -0.34364 -0.14396  0.35207]
 [-0.12686 -0.65474  0.74513  0.4128 ]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] DEBUG W_H_F_iDynTree:
jaxsim[100706] DEBUG [[-0.6807  -0.7098  -0.18119 -0.32363]
 [ 0.69106 -0.70426  0.16267  0.01637]
 [-0.24307 -0.01448  0.9699   1.14715]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] ERROR Assertion failed for frame: r_hip_3
jaxsim[100706] DEBUG W_H_F_js:
jaxsim[100706] DEBUG [[-0.33184 -0.18322  0.92537 -0.40775]
 [ 0.34707 -0.93586 -0.06083  0.13311]
 [ 0.87717  0.30098  0.37415  0.77404]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] DEBUG W_H_F_iDynTree:
jaxsim[100706] DEBUG [[-0.80645 -0.10969 -0.58104 -0.02238]
 [ 0.46848 -0.71808 -0.51467  0.27463]
 [-0.36078 -0.68726  0.63049  0.81591]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] ERROR Assertion failed for frame: r_shoulder_3
jaxsim[100706] DEBUG W_H_F_js:
jaxsim[100706] DEBUG [[-0.35031 -0.67322 -0.6512   0.46478]
 [ 0.928   -0.34364 -0.14396  0.34975]
 [-0.12686 -0.65474  0.74513  0.41312]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] DEBUG W_H_F_iDynTree:
jaxsim[100706] DEBUG [[-0.91833 -0.10879 -0.38058 -0.17789]
 [ 0.21158 -0.94752 -0.23966  0.31045]
 [-0.33454 -0.30061  0.89315  1.20324]
 [ 0.       0.       0.       1.     ]]

@xela-95
Copy link
Member Author

xela-95 commented May 10, 2024

The math compute by the jacobian function is pretty simple:

  • Get the parent link transform ${}^WH_L$
  • Get the pose of the frame with respect to its parent link ${}^FH_L$ by accessing the pose attribute of the frame
    pose: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.eye(4), repr=False)
  • Compute ${}^WH_F = {}^WH_L \cdot {}^LH_F$

Now, since this test is passing for the majority of the frames and failing only for 6 of them, what I'm suspecting is that maybe is not always true that the pose of the frame is relative to its parent link. What do you think @diegoferigo?

@traversaro
Copy link
Contributor

Right now on ErgoCubReduced the jacobians that are not matching are the ones for the frames:

* l_foot_front

* l_foot_rear

* l_hip_3

* l_shoulder_3

* r_hip_3

* r_shoulder3

These test cases are failing in all 3 velocity representations (inertial, body and fixed). All the other frames are passing the test.

The specificity of all this frames is that they are not leaf "fake link frames" but rather proper links (with an inertia) that are lumped to their parents. Perhaps the lumping is not working as expected either in iDynTree or rod? Do we have a check that simply checks forward kinematics for those frames, instead of checking the jacobian?

@traversaro
Copy link
Contributor

Here's the log containg the homogeneous trasnforms from world to frame that are failing:
Logs

jaxsim[100706] ERROR Assertion failed for frame: l_foot_front
jaxsim[100706] DEBUG W_H_F_js:
jaxsim[100706] DEBUG [[-0.35031 -0.67322 -0.6512   0.41084]
 [ 0.928   -0.34364 -0.14396  0.4807 ]
 [-0.12686 -0.65474  0.74513  0.39986]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] DEBUG W_H_F_iDynTree:
jaxsim[100706] DEBUG [[-0.54756 -0.36036 -0.75519 -0.36317]
 [-0.15713 -0.84218  0.51579 -0.03562]
 [-0.82188  0.40109  0.40453  0.33124]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] ERROR Assertion failed for frame: l_foot_rear
jaxsim[100706] DEBUG W_H_F_js:
jaxsim[100706] DEBUG [[-0.35031 -0.67322 -0.6512   0.45261]
 [ 0.928   -0.34364 -0.14396  0.37004]
 [-0.12686 -0.65474  0.74513  0.41499]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] DEBUG W_H_F_iDynTree:
jaxsim[100706] DEBUG [[-0.54756 -0.36036 -0.75519 -0.29787]
 [-0.15713 -0.84218  0.51579 -0.01688]
 [-0.82188  0.40109  0.40453  0.42925]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] ERROR Assertion failed for frame: l_hip_3
jaxsim[100706] DEBUG W_H_F_js:
jaxsim[100706] DEBUG [[-0.80645 -0.10969 -0.58104 -0.02238]
 [ 0.46848 -0.71808 -0.51467  0.27463]
 [-0.36078 -0.68726  0.63049  0.81591]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] DEBUG W_H_F_iDynTree:
jaxsim[100706] DEBUG [[-0.37434 -0.05964  0.92537 -0.24618]
 [ 0.00864 -0.99811 -0.06083  0.09552]
 [ 0.92725 -0.01478  0.37415  0.80239]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] ERROR Assertion failed for frame: l_shoulder_3
jaxsim[100706] DEBUG W_H_F_js:
jaxsim[100706] DEBUG [[-0.35031 -0.67322 -0.6512   0.46391]
 [ 0.928   -0.34364 -0.14396  0.35207]
 [-0.12686 -0.65474  0.74513  0.4128 ]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] DEBUG W_H_F_iDynTree:
jaxsim[100706] DEBUG [[-0.6807  -0.7098  -0.18119 -0.32363]
 [ 0.69106 -0.70426  0.16267  0.01637]
 [-0.24307 -0.01448  0.9699   1.14715]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] ERROR Assertion failed for frame: r_hip_3
jaxsim[100706] DEBUG W_H_F_js:
jaxsim[100706] DEBUG [[-0.33184 -0.18322  0.92537 -0.40775]
 [ 0.34707 -0.93586 -0.06083  0.13311]
 [ 0.87717  0.30098  0.37415  0.77404]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] DEBUG W_H_F_iDynTree:
jaxsim[100706] DEBUG [[-0.80645 -0.10969 -0.58104 -0.02238]
 [ 0.46848 -0.71808 -0.51467  0.27463]
 [-0.36078 -0.68726  0.63049  0.81591]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] ERROR Assertion failed for frame: r_shoulder_3
jaxsim[100706] DEBUG W_H_F_js:
jaxsim[100706] DEBUG [[-0.35031 -0.67322 -0.6512   0.46478]
 [ 0.928   -0.34364 -0.14396  0.34975]
 [-0.12686 -0.65474  0.74513  0.41312]
 [ 0.       0.       0.       1.     ]]
jaxsim[100706] DEBUG W_H_F_iDynTree:
jaxsim[100706] DEBUG [[-0.91833 -0.10879 -0.38058 -0.17789]
 [ 0.21158 -0.94752 -0.23966  0.31045]
 [-0.33454 -0.30061  0.89315  1.20324]
 [ 0.       0.       0.       1.     ]]

These transform are complex. Can you try to set the w_H_B to identity and the joint position to zero, and try again? In that condition the W_H_F rotation of all the leg frames should be the identity, making it more easy to debug.

@traversaro
Copy link
Contributor

Right now on ErgoCubReduced the jacobians that are not matching are the ones for the frames:

* l_foot_front

* l_foot_rear

* l_hip_3

* l_shoulder_3

* r_hip_3

* r_shoulder3

These test cases are failing in all 3 velocity representations (inertial, body and fixed). All the other frames are passing the test.

The specificity of all this frames is that they are not leaf "fake link frames" but rather proper links (with an inertia) that are lumped to their parents. Perhaps the lumping is not working as expected either in iDynTree or rod? Do we have a check that simply checks forward kinematics for those frames, instead of checking the jacobian?

Sorry, I read later #148 (comment), this is already happening.

@xela-95
Copy link
Member Author

xela-95 commented May 10, 2024

These transform are complex. Can you try to set the w_H_B to identity and the joint position to zero, and try again? In that condition the W_H_F rotation of all the leg frames should be the identity, making it more easy to debug.

This is a good idea, but I do not know how to change base and joint position programmatically. I'll try to understand how to do this.

@diegoferigo
Copy link
Member

diegoferigo commented May 10, 2024

These transform are complex. Can you try to set the w_H_B to identity and the joint position to zero, and try again? In that condition the W_H_F rotation of all the leg frames should be the identity, making it more easy to debug.

This is a good idea, but I do not know how to change base and joint position programmatically. I'll try to understand how to do this.

If you comment out these lines, by default JaxSimModelData is populated with a trivial orientation and zero data (check zero). Then, you can use the reset* methods to set only parts of the configuration.

@diegoferigo
Copy link
Member

diegoferigo commented May 10, 2024

Now, since this test is passing for the majority of the frames and failing only for 6 of them, what I'm suspecting is that maybe is not always true that the pose of the frame is relative to its parent link. What do you think @diegoferigo?

Maybe, I never tested thoroughly the frame-related logic since this is the first time we are using it. I suspect there might be a bug when the pose of the frames is resolved. This is called when the URDF exported of rod switches to FrameConvention.Urdf here.

@diegoferigo diegoferigo force-pushed the feature/expose-frame-quantities branch from 561923a to 8498c1b Compare May 22, 2024 08:01
@diegoferigo diegoferigo force-pushed the feature/expose-frame-quantities branch from 8498c1b to 61d71a4 Compare May 22, 2024 08:18
@diegoferigo
Copy link
Member

Trying to fix once again in 61d71a4 a regression introduced in this PR related to #103.

@diegoferigo diegoferigo force-pushed the feature/expose-frame-quantities branch from 61d71a4 to a0ed4a8 Compare May 22, 2024 08:23
@diegoferigo
Copy link
Member

diegoferigo commented May 22, 2024

Trying to fix once again in 61d71a4 a regression introduced in this PR related to #103.

Mmh nope it didn't work. We need more time to solve this for good. I propose to proceed with that test disabled so we can start using frames, and fix the problem in another PR. @flferretti

It's worth noting that the test succeeds if I run it locally.

@flferretti
Copy link
Collaborator

It's worth noting that the test succeeds if I run it locally.

Are you using GPU or CPU?

@diegoferigo
Copy link
Member

It's worth noting that the test succeeds if I run it locally.

Are you using GPU or CPU?

Now CPU, but yesterday I tested with GPU and locally was passing either.

@flferretti
Copy link
Collaborator

flferretti commented May 22, 2024

Can you please try to run this and see if you get the same error?

Test Script

import jax.numpy as jnp
import jaxsim.api as js
import rod.builder.primitives
import rod.urdf.exporter
from jaxsim import integrators

# Create on-the-fly a ROD model of a box.
rod_model = (
    rod.builder.primitives.BoxBuilder(x=0.3, y=0.2, z=0.1, mass=1.0, name="box")
    .build_model()
    .add_link()
    .add_inertial()
    .add_visual()
    .add_collision()
    .build()
)

# Export the URDF string.
urdf_string = rod.urdf.exporter.UrdfExporter.sdf_to_urdf_string(
    sdf=rod_model, pretty=True
)

model1 = js.model.JaxSimModel.build_from_model_description(
    model_description=urdf_string,
    is_urdf=True,
)

model2 = js.model.JaxSimModel.build_from_model_description(
    model_description=urdf_string,
    is_urdf=True,
)

# Build the data
data1 = js.data.JaxSimModelData.build(model=model1)

data2 = js.data.JaxSimModelData.build(model=model2)

# Create the integrators
integrator1 = integrators.fixed_step.Heun2SO3.build(
    dynamics=js.ode.wrap_system_dynamics_for_integration(
        model=model1,
        data=data1,
        system_dynamics=js.ode.system_dynamics,
    ),
)

integrator2 = integrators.fixed_step.Heun2SO3.build(
    dynamics=js.ode.wrap_system_dynamics_for_integration(
        model=model2,
        data=data2,
        system_dynamics=js.ode.system_dynamics,
    ),
)

# ! Try to initialize the integrator
integrator_state1 = integrator1.init(x0=data1.state, t0=0, dt=1e-3)

integrator_state2 = integrator2.init(x0=data2.state, t0=0, dt=1e-3)

@diegoferigo
Copy link
Member

That is instead failing as reported below.

Failure
jaxsim[12474] INFO Enabling JAX to use 64bit precision
jaxsim[12474] DEBUG JAX compilation cache is not supported on CPU
rod[12474] INFO Calling sdformat through '/jaxsim/bin/gz sdf'
rod[12474] DEBUG Building model 'box'
rod[12474] WARNING This method is deprecated, please use 'UrdfExporter.to_urdf_string' instead.
rod[12474] DEBUG Converting model 'box' to URDF
rod[12474] DEBUG Detected 'box_link' as root link
rod[12474] DEBUG Building kinematic tree of model 'box'
rod[12474] DEBUG Selecting 'box_link' as canonical link
rod[12474] DEBUG Node 'world' became a frame attached to 'box_link'
rod[12474] DEBUG Building kinematic tree of model 'box'
rod[12474] DEBUG Selecting 'box_link' as canonical link
rod[12474] DEBUG Node 'world' became a frame attached to 'box_link'
jaxsim[12474] DEBUG Found model 'box' in SDF resource
rod[12474] DEBUG Building kinematic tree of model 'box'
rod[12474] DEBUG Selecting 'box_link' as canonical link
rod[12474] DEBUG Node 'world' became a frame attached to 'box_link'
jaxsim[12474] DEBUG Model 'box' is floating-base
jaxsim[12474] DEBUG Considering 'box_link' as base link
jaxsim[12474] INFO The kinematic graph doesn't need to be reduced
jaxsim[12474] DEBUG Found model 'box' in SDF resource
rod[12474] DEBUG Building kinematic tree of model 'box'
rod[12474] DEBUG Selecting 'box_link' as canonical link
rod[12474] DEBUG Node 'world' became a frame attached to 'box_link'
jaxsim[12474] DEBUG Model 'box' is floating-base
jaxsim[12474] DEBUG Considering 'box_link' as base link
jaxsim[12474] INFO The kinematic graph doesn't need to be reduced
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[1], line 58
     55 # ! Try to initialize the integrator
     56 integrator_state1 = integrator1.init(x0=data1.state, t0=0, dt=1e-3)
---> 58 integrator_state2 = integrator2.init(x0=data2.state, t0=0, dt=1e-3)

File ~/git/jaxsim/src/jaxsim/integrators/common.py:235, in Integrator.init(self, x0, t0, dt, include_dynamics_aux_dict, **kwargs)
    227     integrator.params = {
    228         Integrator.InitializingKey: jnp.array(True),
    229         Integrator.AfterInitKey: jnp.array(False),
    230     }
    232     # Run a dummy call of the integrator.
    233     # It is used only to get the params so that we know the structure
    234     # of the corresponding pytree.
--> 235     _ = integrator(x0, t0, dt, **kwargs)
    237     # integrator.params = {Integrator.AfterInitKey: jnp.array(False).astype(bool)}
    238     # aux_dict_step = integrator.params
    239 
    240 # Remove the injected key.
    241 _ = integrator.params.pop(Integrator.InitializingKey)

File ~/git/jaxsim/src/jaxsim/integrators/common.py:403, in ExplicitRungeKutta.__call__(self, x0, t0, dt, **kwargs)
    398 def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
    399 
    400     # Here z is a batched state with as many batch elements as b.T rows.
    401     # Note that z has multiple batches only if b.T has more than one row,
    402     # e.g. in Butcher tableau of embedded schemes.
--> 403     z = self._compute_next_state(x0=x0, t0=t0, dt=dt, **kwargs)
    405     # The next state is the batch element located at the configured index of solution.
    406     return jax.tree_util.tree_map(lambda l: l[self.row_index_of_solution], z)

File ~/git/jaxsim/src/jaxsim/integrators/common.py:615, in ExplicitRungeKutta._compute_next_state(self, x0, t0, dt, **kwargs)
    612     return carry, None
    614 # Compute the state derivatives kᵢ.
--> 615 K, _ = jax.lax.scan(
    616     f=scan_body,
    617     init=carry0,
    618     xs=jnp.arange(c.size),
    619 )
    621 # Update the FSAL property for the next iteration.
    622 if self.has_fsal:
    623     # print(">>pre", self.params["dxdt0"])
    624     # print(">>pre")
    625     # print(K)

    [... skipping hidden 9 frame]

File ~/git/jaxsim/src/jaxsim/integrators/common.py:601, in ExplicitRungeKutta._compute_next_state.<locals>.scan_body(carry, i)
    564     return f(xi, ti)[0]
    566 # # Define the computation of the Runge-Kutta stage.
    567 # def compute_ki() -> jax.Array:
    568 #     # print(A[i, :].shape)
   (...)
    599 
    600 # This selector enables FSAL property in the first iteration (i=0).
--> 601 ki = jax.lax.cond(
    602     pred=jnp.logical_and(i == 0, self.has_fsal),
    603     true_fun=get_ẋ0,
    604     false_fun=compute_ki,
    605 )
    607 # Store the kᵢ derivative in K.
    608 op = lambda l_k, l_ki: l_k.at[i].set(l_ki)

    [... skipping hidden 11 frame]

File ~/git/jaxsim/src/jaxsim/integrators/common.py:507, in ExplicitRungeKutta._compute_next_state.<locals>.<lambda>()
    490 carry0 = jax.tree_util.tree_map(
    491     lambda l: jnp.repeat(jnp.zeros_like(l)[jnp.newaxis, ...], c.size, axis=0),
    492     x0,
    493 )
    494 # print(carry0)
    495 
    496 # Allocate the parameter to store the FSAL derivative.
   (...)
    505 
    506 # Apply FSAL property by passing ẋ0 = f(x0, t0) from the previous iteration.
--> 507 get_ẋ0 = lambda: self.params.get("dxdt0", f(x0, t0)[0])
    508 # get_ẋ0 = lambda: self.params.get("dxdt0") if self.fsal else f(x0, t0)[0]
    509 # get_ẋ0 = lambda: (
    510 #     self.params["dxdt0"] if self.has_fsal else lambda: f(x0, t0)[0]
   (...)
    517 # Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code
    518 # would include 4 repetitions of the `f` logic, making everything extremely slow.
    519 def scan_body(carry: jax.Array, i: int | jax.Array) -> tuple[jax.Array, None]:

File ~/git/jaxsim/src/jaxsim/integrators/common.py:485, in ExplicitRungeKutta._compute_next_state.<locals>.<lambda>(x, t)
    482 A = self.A
    484 # Close f over optional kwargs.
--> 485 f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)
    487 # Initialize the carry of the for loop with the stacked kᵢ vectors.
    488 # carry0 = jnp.zeros(shape=(c.size, x0.size), dtype=float)
    489 # carry0 = jnp.repeat(jnp.zeros_like(x0)[jnp.newaxis, ...], c.size, axis=0)
    490 carry0 = jax.tree_util.tree_map(
    491     lambda l: jnp.repeat(jnp.zeros_like(l)[jnp.newaxis, ...], c.size, axis=0),
    492     x0,
    493 )

File ~/git/jaxsim/src/jaxsim/api/ode.py:94, in wrap_system_dynamics_for_integration.<locals>.f(x, t, **kwargs_f)
     89     data_rw.simulation_time = js.common.Time.build(t_ns=t * 1e9)
     90     # data_rw.time_ns = jnp.array(t * 1e9).astype(data_rw.time_ns.dtype)
     91 
     92 # Evaluate the system dynamics, allowing to override the kwargs originally
     93 # passed when the closure was created.
---> 94 return system_dynamics(
     95     model=model_f,
     96     data=data_rw,
     97     **(kwargs_closed | kwargs_f),
     98 )

    [... skipping hidden 5 frame]

File /jaxsim/lib/python3.12/site-packages/jax/_src/core.py:1475, in concretization_function_error.<locals>.error(self, arg)
   1474 def error(self, arg):
-> 1475   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[1]..
The error occurred while tracing the function <lambda> at /home/dferigo/git/jaxsim/src/jaxsim/integrators/common.py:507 for cond. This value became a tracer due to JAX operations on these lines:

  operation a:bool[1] = eq b c
    from line /home/dferigo/git/jaxsim/src/jaxsim/api/ode.py:94:15 (wrap_system_dynamics_for_integration.<locals>.f)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

@flferretti
Copy link
Collaborator

If it is blocking for this PR, feel free to merge, we can address this in the future

@diegoferigo
Copy link
Member

Yeah I guess your script can be a good start to look for a permanent fix (or better, a pattern that would no longer produce this problem). Since your problem was already there, but for some reason the pytree test was not affected, I propose to proceed merging. Let me know if we have applications that are affected (I don't think so).

@diegoferigo diegoferigo force-pushed the feature/expose-frame-quantities branch from c498d6e to 8b5ac05 Compare May 22, 2024 09:02
@flferretti
Copy link
Collaborator

I believe I found a possible reason, check #103 (comment)

@diegoferigo diegoferigo merged commit e341015 into ami-iit:main May 23, 2024
15 checks passed
flferretti added a commit that referenced this pull request May 23, 2024
* Create `jaxsim.api.frame` module with `transform` function

* Add `frame` module to `jaxsim.api` package

* Add unit test for `jaxsim.api.frame` module

* Add index-related functions to `frame` module

* Add `test_frame_index` to `frame` unit tests

* Add `jacobian` method to `frame` module

* Add `test_frame_jacobians` to `frame` unit tests

* Add `.vscode` to gitignore

* Add `frame` module to sphynx documentation

* Apply suggestions from code review

Co-authored-by: Filippo Luca Ferretti <[email protected]>
Co-authored-by: Diego Ferigo <[email protected]>

* Update frame.py

* Add additional frame attached to link in box model

* Refactor `test_frame_jacobians` to better debug jacobians not matching with iDynTree

* Exclude from `test_frame_jacobians` the frames that are not loaded in KynDynComputations

* Add `frame_parent_link_name` method to `KinDynComputations class`

* Update code style of `frame.transform` function

* WIP Update `test_frame_transforms` to print parent link frames and not fail at the first failed assertion

* Update `test_frame_transforms`

* Add single pendulum fixture in `conftest.py`

* Clean `test_api_frames`

* Fix retrieval of the frame's parent link index

* Add function to get the frame's parent link index

* Update frames test

* Align link and joint tests

* Removed unused tested model

* Update tests to use new ROD URDF exporter function

* Use plain integers for frame indices

* Add JaxSimModel.frame_names

* Fix regression raising TracerBoolConversionError when comparing pytrees

* Temporarily disable test_pytree

---------

Co-authored-by: Filippo Luca Ferretti <[email protected]>
Co-authored-by: Diego Ferigo <[email protected]>
Co-authored-by: diegoferigo <[email protected]>
flferretti added a commit that referenced this pull request Jun 3, 2024
* Create `jaxsim.api.frame` module with `transform` function

* Add `frame` module to `jaxsim.api` package

* Add unit test for `jaxsim.api.frame` module

* Add index-related functions to `frame` module

* Add `test_frame_index` to `frame` unit tests

* Add `jacobian` method to `frame` module

* Add `test_frame_jacobians` to `frame` unit tests

* Add `.vscode` to gitignore

* Add `frame` module to sphynx documentation

* Apply suggestions from code review

Co-authored-by: Filippo Luca Ferretti <[email protected]>
Co-authored-by: Diego Ferigo <[email protected]>

* Update frame.py

* Add additional frame attached to link in box model

* Refactor `test_frame_jacobians` to better debug jacobians not matching with iDynTree

* Exclude from `test_frame_jacobians` the frames that are not loaded in KynDynComputations

* Add `frame_parent_link_name` method to `KinDynComputations class`

* Update code style of `frame.transform` function

* WIP Update `test_frame_transforms` to print parent link frames and not fail at the first failed assertion

* Update `test_frame_transforms`

* Add single pendulum fixture in `conftest.py`

* Clean `test_api_frames`

* Fix retrieval of the frame's parent link index

* Add function to get the frame's parent link index

* Update frames test

* Align link and joint tests

* Removed unused tested model

* Update tests to use new ROD URDF exporter function

* Use plain integers for frame indices

* Add JaxSimModel.frame_names

* Fix regression raising TracerBoolConversionError when comparing pytrees

* Temporarily disable test_pytree

---------

Co-authored-by: Filippo Luca Ferretti <[email protected]>
Co-authored-by: Diego Ferigo <[email protected]>
Co-authored-by: diegoferigo <[email protected]>
flferretti added a commit that referenced this pull request Jun 13, 2024
* Create `jaxsim.api.frame` module with `transform` function

* Add `frame` module to `jaxsim.api` package

* Add unit test for `jaxsim.api.frame` module

* Add index-related functions to `frame` module

* Add `test_frame_index` to `frame` unit tests

* Add `jacobian` method to `frame` module

* Add `test_frame_jacobians` to `frame` unit tests

* Add `.vscode` to gitignore

* Add `frame` module to sphynx documentation

* Apply suggestions from code review

Co-authored-by: Filippo Luca Ferretti <[email protected]>
Co-authored-by: Diego Ferigo <[email protected]>

* Update frame.py

* Add additional frame attached to link in box model

* Refactor `test_frame_jacobians` to better debug jacobians not matching with iDynTree

* Exclude from `test_frame_jacobians` the frames that are not loaded in KynDynComputations

* Add `frame_parent_link_name` method to `KinDynComputations class`

* Update code style of `frame.transform` function

* WIP Update `test_frame_transforms` to print parent link frames and not fail at the first failed assertion

* Update `test_frame_transforms`

* Add single pendulum fixture in `conftest.py`

* Clean `test_api_frames`

* Fix retrieval of the frame's parent link index

* Add function to get the frame's parent link index

* Update frames test

* Align link and joint tests

* Removed unused tested model

* Update tests to use new ROD URDF exporter function

* Use plain integers for frame indices

* Add JaxSimModel.frame_names

* Fix regression raising TracerBoolConversionError when comparing pytrees

* Temporarily disable test_pytree

---------

Co-authored-by: Filippo Luca Ferretti <[email protected]>
Co-authored-by: Diego Ferigo <[email protected]>
Co-authored-by: diegoferigo <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Expose quantities related to generic frames similarly to those of link frames
4 participants